
Introduction¶

This article was created by Chris to illustrate a few different methods for optimizing parameters with respect to a loss function, from traditional gradient descent to several alternative approaches. The d2l and mxnet packages need to be installed before proceeding.

This tutorial is about optimization in machine learning and deep learning. Assume that we already have a loss function $l(w)$ that we want to minimize. (If you instead need to maximize a function, simply flip its sign and minimize: maximizing $f(w)$ is equivalent to minimizing $-f(w)$.) Here $l(w)$ is a continuously differentiable real-valued function and $w\in \mathbb{R}^d$ with $d=1,2,\dots$

1. Gradient Descent¶

1.1 In 1-Dimension¶

For the objective function $l(w)$, a Taylor expansion gives $l(w+\epsilon)=l(w)+\epsilon l'(w) + \mathcal{O}(\epsilon^2)$.

Since we want to minimize $l(w)$, we can choose $\epsilon=-\eta l'(w)$ with $\eta>0$; then $l(w-\eta l'(w))=l(w)-\eta l'^2(w) + \mathcal{O}(\eta^2 l'^2(w))$.

That is, for a sufficiently small $\eta$, $l(w-\eta l'(w)) \le l(w)$. Thus, by repeatedly applying the update $w \leftarrow w-\eta l'(w)$, where $\eta$ is the learning rate, we can decrease $l(w)$.

For example, take $l(w)=w^2$, whose minimum value $l(w)=0$ is attained at $w=0$. How do we use gradient descent to solve this minimization problem?

In [ ]:
!pip install d2l==0.16.1
Collecting d2l
  Downloading https://files.pythonhosted.org/packages/30/2b/3515cd6f2898bf95306a5c58b065aeb045fdc25516f2b68b0f8409e320c3/d2l-0.16.1-py3-none-any.whl (76kB)
Installing collected packages: d2l
Successfully installed d2l-0.16.1
In [ ]:
!pip install mxnet
Collecting mxnet
  Downloading https://files.pythonhosted.org/packages/64/20/76af36cad6754a15f39d3bff19e09921dec72b85261e455d4edc50ebffa8/mxnet-1.7.0.post2-py2.py3-none-manylinux2014_x86_64.whl (54.7MB)
Collecting graphviz<0.9.0,>=0.8.1
Installing collected packages: graphviz, mxnet
Successfully installed graphviz-0.8.4 mxnet-1.7.0.post2
In [ ]:
from d2l import mxnet as d2l  # d2l's MXNet module provides set_figsize, plot, and plt
import math
from mxnet import np, npx
npx.set_np()
In [ ]:
def l(w):
    return w**2  # Minimization function

def gradl(w):
    return 2 * w  # The derivative
In [ ]:
def gd(eta, epoches=10):
    w = 10  # initial value
    results = [w]
    for i in range(epoches):
        w -= eta * gradl(w)
        results.append(w)
    print('epoch={}, w={}'.format(epoches,w))
    return results
In [ ]:
gd(0.2)
epoch=10, w=0.06046617599999997
Out[ ]:
[10,
 6.0,
 3.5999999999999996,
 2.1599999999999997,
 1.2959999999999998,
 0.7775999999999998,
 0.46655999999999986,
 0.2799359999999999,
 0.16796159999999993,
 0.10077695999999996,
 0.06046617599999997]
In [ ]:
def show_trace(res):
    """Plot l(w) together with the sequence of w values visited by gradient descent."""
    n = max(abs(min(res)), abs(max(res)))
    l_line = np.arange(-n, n, 0.01)
    d2l.set_figsize((3.5, 2.5))
    d2l.plot([l_line, res], [[l(w) for w in l_line], [l(w) for w in res]],
             'w', 'l(w)', fmts=['-', '-o'])
In [ ]:
show_trace(gd(0.2))
epoch=10, w=0.06046617599999997
In [ ]:
show_trace(gd(0.2,100))
epoch=100, w=6.533186235000684e-22
In [ ]:
show_trace(gd(0.01,100))
epoch=100, w=1.3261955589475318
In [ ]:
show_trace(gd(1.1,10))
epoch=10, w=61.917364224000096

1.2 Local minima¶

One issue in machine learning and deep learning is local minima. The demo here is $l(w)=w\cos(cw)$, which has many local minima; depending on the starting point and the learning rate, gradient descent can end up in different local minima.

In [ ]:
def annotate(text, xy, xytext):
    d2l.plt.gca().annotate(text, xy=xy, xytext=xytext,
                           arrowprops=dict(arrowstyle='->'))

def l(w): return w * np.cos(np.pi * w)
w = np.arange(-1.0, 2.0, 0.01)
d2l.plot(w, [l(w), ], 'w', 'l(w)')
annotate('local minimum', (-0.3, -0.25), (-0.77, -1.0))
annotate('global minimum in this plot', (1.1, -0.95), (0.6, 0.8))
In [ ]:
def l(w):
    c = 0.15 * np.pi
    return w * np.cos(c * w)

def gradl(w):
    c = 0.15 * np.pi
    return np.cos(c * w) - c * w * np.sin(c * w)
In [ ]:
show_trace(gd(2,30))
epoch=30, w=-1.721549970478682
In [ ]:
show_trace(gd(1,30))
epoch=30, w=7.270386695790638

1.3 Multivariate gradient descent¶

In $l(w)$, $w$ is a vector and $l$ is a function that maps a vector to a scalar. For example, $l(w)=w^2_1+2w^2_2$ where $w=[w_1, w_2]$. We know the minimum of $l(w)$ is 0 at $w=[0,0]$.

In [ ]:
def l_2(w1, w2):
    return w1 ** 2 + 2 * w2 ** 2  # Objective

def gradl_2(w1, w2):
    return (2 * w1, 4 * w2)  # Gradient

def gd_2(w1, w2, s1, s2, eta=0.1, beta=0.5):
    # s1, s2 and beta are unused here; they are kept so gd_2 matches the
    # common trainer signature expected by gd_2d below
    (g1, g2) = gradl_2(w1, w2)  # Compute gradient
    return (w1 - eta * g1, w2 - eta * g2, 0, 0)  # Update variables
In [ ]:
def gd_2d(_gd, epoches=20, eta=0.1, output_det=0, beta=0.5):
    """Optimize a 2-dim objective function with a customized trainer."""
    w1, w2, s1, s2 = -5, -2, 0, 0
    results = [(w1, w2)]
    for i in range(epoches):
        w1, w2, s1, s2 = _gd(w1, w2, s1, s2, eta, beta)
        results.append((w1, w2))
        if output_det>0:
            print('epoch %d, w1 %f, w2 %f' % (i + 1, w1, w2))
    print('epoch %d, w1 %f, w2 %f' % (i + 1, w1, w2))
    return results

def show_trace_2d(l, results):
    """Show the trace of 2D variables during optimization."""
    d2l.set_figsize((3.5, 2.5))
    d2l.plt.plot(*zip(*results), '-o', color='#ff7f0e')
    x1, x2 = np.meshgrid(np.arange(-5.5, 1.0, 0.1), np.arange(-3.0, 1.0, 0.1))
    d2l.plt.contour(x1, x2, l(x1, x2), colors='#1f77b4')
    d2l.plt.xlabel('w1')
    d2l.plt.ylabel('w2')
In [ ]:
show_trace_2d(l_2, gd_2d(gd_2,5,0.1))
epoch 5, w1 -1.638400, w2 -0.155520
In [ ]:
show_trace_2d(l_2, gd_2d(gd_2,10,0.1))
epoch 10, w1 -0.536871, w2 -0.012093
In [ ]:
show_trace_2d(l_2, gd_2d(gd_2,10,0.2))
epoch 10, w1 -0.030233, w2 -0.000000

2. Stochastic gradient descent¶

The gradient descent above treats the objective as a single function. With a training set of $N$ examples, the objective becomes $l(w)=\frac{1}{N}\sum_{i=1}^{N}l_i(x_i,w)$, and its gradient is $\nabla l(w)=\frac{1}{N}\sum_{i=1}^{N}\nabla l_i(x_i,w)$.

Thus, when the training data contain $N$ instances, computing one full gradient costs $O(N)$, which might be too high. In stochastic gradient descent (SGD), $w$ is updated by $w - \eta\nabla l_i(x_i,w)$ rather than by $w - \eta\frac{1}{N}\sum_{i=1}^{N}\nabla l_i(x_i,w)$, where $i$ is drawn uniformly at random from $\{1, \dots, N\}$. The stochastic gradient is an unbiased estimate of the full gradient, but each update costs only $O(1)$.
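
As a minimal illustrative sketch (not a canonical SGD implementation), the cell below simulates stochastic gradients for the quadratic objective $l(w)=w_1^2+2w_2^2$ from Section 1.3 by adding Gaussian noise to the exact gradient; the noise stands in for the variance of a gradient computed on a single randomly drawn example, and the helper name sgd_2d, the noise scale, and the learning rate are illustrative choices.

In [ ]:
import random  # standard-library RNG, used only to simulate per-example gradient noise

def sgd_2d(w1, w2, s1, s2, eta=0.1, beta=0.5):
    # The exact gradient of l_2(w) = w1**2 + 2*w2**2 is (2*w1, 4*w2); the added
    # noise mimics using a single randomly sampled example instead of all N.
    g1 = 2 * w1 + random.gauss(0, 1)
    g2 = 4 * w2 + random.gauss(0, 1)
    return w1 - eta * g1, w2 - eta * g2, 0, 0

show_trace_2d(l_2, gd_2d(sgd_2d, 50, 0.05))

With a constant learning rate the stochastic iterates do not settle exactly at the minimum but keep wandering in a small noisy region around it, which is why the learning rate is usually decayed over time in practice.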


3. Momentum¶

To understand the importance of momentum, we first look at an ill-conditioned objective function for gradient descent.

$l(w) = 0.1w_1^2 + 2w_2^2$

where the minimum is 0 at (0,0).

What happens if we use gradient descent for this $l(w)$?

In [ ]:
def l_2i(w1, w2):
    return 0.1 * w1 ** 2 + 2 * w2 ** 2
def gd_2i(w1, w2, s1, s2, eta=0.1, beta=0.5):
    return (w1 - eta * 0.2 * w1, w2 - eta * 4 * w2, 0, 0)
In [ ]:
show_trace_2d(l_2i, gd_2d(gd_2i,20,0.4,1))
epoch 1, w1 -4.600000, w2 1.200000
epoch 2, w1 -4.232000, w2 -0.720000
epoch 3, w1 -3.893440, w2 0.432000
epoch 4, w1 -3.581965, w2 -0.259200
epoch 5, w1 -3.295408, w2 0.155520
epoch 6, w1 -3.031775, w2 -0.093312
epoch 7, w1 -2.789233, w2 0.055987
epoch 8, w1 -2.566094, w2 -0.033592
epoch 9, w1 -2.360807, w2 0.020155
epoch 10, w1 -2.171942, w2 -0.012093
epoch 11, w1 -1.998187, w2 0.007256
epoch 12, w1 -1.838332, w2 -0.004354
epoch 13, w1 -1.691265, w2 0.002612
epoch 14, w1 -1.555964, w2 -0.001567
epoch 15, w1 -1.431487, w2 0.000940
epoch 16, w1 -1.316968, w2 -0.000564
epoch 17, w1 -1.211611, w2 0.000339
epoch 18, w1 -1.114682, w2 -0.000203
epoch 19, w1 -1.025507, w2 0.000122
epoch 20, w1 -0.943467, w2 -0.000073
epoch 20, w1 -0.943467, w2 -0.000073
In [ ]:
show_trace_2d(l_2i, gd_2d(gd_2i,20,0.6,1))
epoch 1, w1 -4.400000, w2 2.800000
epoch 2, w1 -3.872000, w2 -3.920000
epoch 3, w1 -3.407360, w2 5.488000
epoch 4, w1 -2.998477, w2 -7.683200
epoch 5, w1 -2.638660, w2 10.756480
epoch 6, w1 -2.322020, w2 -15.059072
epoch 7, w1 -2.043378, w2 21.082701
epoch 8, w1 -1.798173, w2 -29.515781
epoch 9, w1 -1.582392, w2 41.322094
epoch 10, w1 -1.392505, w2 -57.850931
epoch 11, w1 -1.225404, w2 80.991303
epoch 12, w1 -1.078356, w2 -113.387825
epoch 13, w1 -0.948953, w2 158.742955
epoch 14, w1 -0.835079, w2 -222.240137
epoch 15, w1 -0.734869, w2 311.136191
epoch 16, w1 -0.646685, w2 -435.590668
epoch 17, w1 -0.569083, w2 609.826935
epoch 18, w1 -0.500793, w2 -853.757708
epoch 19, w1 -0.440698, w2 1195.260792
epoch 20, w1 -0.387814, w2 -1673.365109
epoch 20, w1 -0.387814, w2 -1673.365109

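Why does the second run diverge? For this objective the gradient descent update of each coordinate is linear, so it can be tracked in closed form:

$w_1 \leftarrow (1 - 0.2\eta)\,w_1, \qquad w_2 \leftarrow (1 - 4\eta)\,w_2$

The iterates contract only when the factor has magnitude below 1, which for $w_2$ requires $\eta < 0.5$: with $\eta=0.4$ the factor is $-0.6$ and $w_2$ oscillates toward 0, but with $\eta=0.6$ the factor is $-1.4$ and $|w_2|$ grows by 40% per epoch, exactly as in the output above. Meanwhile the $w_1$ factor stays close to 1 ($0.92$ and $0.88$, respectively), so $w_1$ only crawls toward 0.
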
One solution to the issue above is to use momentum: $v_t = \beta v_{t-1} + \nabla_w l(x_t, w)$, where $v_0=0$ and $\beta \in (0,1)$, and then update $w_t$ by $w_{t-1} - \eta_t v_t$. The velocity $v_t$ averages recent gradients, damping the oscillation along $w_2$ while preserving steady progress along $w_1$. Below, gd_2d is redefined so that the final epoch is printed only when per-epoch output is off, show_trace_2d is restated with w1/w2 naming, and momentum_2_d implements the momentum update.

In [ ]:
def gd_2d(_gd, epoches=20, eta=0.1, output_det=0, beta=0.5):
    """Optimize a 2-dim objective function with a customized trainer."""
    w1, w2, s1, s2 = -5, -2, 0, 0
    results = [(w1, w2)]
    for i in range(epoches):
        w1, w2, s1, s2 = _gd(w1, w2, s1, s2, eta, beta)
        results.append((w1, w2))
        if output_det>0:
            print('epoch %d, w1 %f, w2 %f' % (i + 1, w1, w2))
    if output_det==0: print('epoch %d, w1 %f, w2 %f' % (i + 1, w1, w2))
    return results

def show_trace_2d(l, results):
    """Show the trace of 2D variables during optimization."""
    d2l.set_figsize((3.5, 2.5))
    d2l.plt.plot(*zip(*results), '-o', color='#ff7f0e')
    w1, w2 = np.meshgrid(np.arange(-5.5, 1.0, 0.1), np.arange(-3.0, 1.0, 0.1))
    d2l.plt.contour(w1, w2, l(w1, w2), colors='#1f77b4')
    d2l.plt.xlabel('w1')
    d2l.plt.ylabel('w2')
    
def momentum_2_d(w1, w2, v1, v2, eta=0.8, beta=0.5):
    # Velocity: exponentially weighted sum of the gradients (0.2*w1, 4*w2)
    v1 = beta * v1 + 0.2 * w1
    v2 = beta * v2 + 4 * w2
    return w1 - eta * v1, w2 - eta * v2, v1, v2
In [ ]:
show_trace_2d(l_2i, gd_2d(momentum_2_d,20,0.6,1))
epoch 1, w1 -4.400000, w2 2.800000
epoch 2, w1 -3.572000, w2 -1.520000
epoch 3, w1 -2.729360, w2 -0.032000
epoch 4, w1 -1.980517, w2 0.788800
epoch 5, w1 -1.368433, w2 -0.693920
epoch 6, w1 -0.898179, w2 0.230128
epoch 7, w1 -0.555271, w2 0.139845
epoch 8, w1 -0.317184, w2 -0.240924
epoch 9, w1 -0.160079, w2 0.146909
epoch 10, w1 -0.062317, w2 -0.011756
epoch 11, w1 -0.005957, w2 -0.062874
epoch 12, w1 0.022937, w2 0.062465
epoch 13, w1 0.034632, w2 -0.024781
epoch 14, w1 0.036323, w2 -0.008929
epoch 15, w1 0.032810, w2 0.020427
epoch 16, w1 0.027117, w2 -0.013920
epoch 17, w1 0.021016, w2 0.002314
epoch 18, w1 0.015443, w2 0.004877
epoch 19, w1 0.010804, w2 -0.005546
epoch 20, w1 0.007188, w2 0.002553
In [ ]:
show_trace_2d(l_2i, gd_2d(momentum_2_d,20,0.6,1,0.25))
epoch 1, w1 -4.400000, w2 2.800000
epoch 2, w1 -3.722000, w2 -2.720000
epoch 3, w1 -3.105860, w2 2.428000
epoch 4, w1 -2.579122, w2 -2.112200
epoch 5, w1 -2.137943, w2 1.822030
epoch 6, w1 -1.771095, w2 -1.567284
epoch 7, w1 -1.466851, w2 1.346870
epoch 8, w1 -1.214768, w2 -1.157079
epoch 9, w1 -1.005975, w2 0.993923
epoch 10, w1 -0.833060, w2 -0.853742
epoch 11, w1 -0.689864, w2 0.733323
epoch 12, w1 -0.571281, w2 -0.629886
epoch 13, w1 -0.473082, w2 0.541038
epoch 14, w1 -0.391762, w2 -0.464722
epoch 15, w1 -0.324421, w2 0.399171
epoch 16, w1 -0.268655, w2 -0.342866
epoch 17, w1 -0.222475, w2 0.294503
epoch 18, w1 -0.184233, w2 -0.252962
epoch 19, w1 -0.152564, w2 0.217281
epoch 20, w1 -0.126340, w2 -0.186632

4. Adagrad by [Duchi et al., 2011]¶

Like momentum, Adagrad keeps per-variable state that modifies the raw gradient update; however, Adagrad adapts the learning rate of each variable based on the accumulated squares of its previously observed gradients:

$g_t = \nabla_w l(x_t, w)$

$s_t = s_{t-1} + g_t^2$

$w_t = w_{t-1} - \frac{\eta}{\sqrt{s_t+\epsilon}}g_t$

where $s_0=0$ and $\epsilon$ is a small positive value added for numerical stability.

We again use $l(w) = 0.1 w_1^2 + 2 w_2^2$.

In [ ]:
def adagrad_2d(w1, w2, s1, s2, eta=0.4, eps=1e-6):
    g1, g2 = 0.2 * w1, 4 * w2
    s1 += g1 ** 2
    s2 += g2 ** 2
    w1 -= eta / math.sqrt(s1 + eps) * g1
    w2 -= eta / math.sqrt(s2 + eps) * g2
    return w1, w2, s1, s2
In [ ]:
show_trace_2d(l_2i, gd_2d(adagrad_2d,20,0.4,1,1e-6))
epoch 1, w1 -4.600000, w2 -1.600000
epoch 2, w1 -4.329178, w2 -1.350122
epoch 3, w1 -4.114228, w2 -1.163597
epoch 4, w1 -3.932302, w2 -1.014436
epoch 5, w1 -3.772835, w2 -0.890767
epoch 6, w1 -3.629933, w2 -0.785968
epoch 7, w1 -3.499909, w2 -0.695875
epoch 8, w1 -3.380281, w2 -0.617648
epoch 9, w1 -3.269280, w2 -0.549239
epoch 10, w1 -3.165593, w2 -0.489098
epoch 11, w1 -3.068216, w2 -0.436016
epoch 12, w1 -2.976356, w2 -0.389023
epoch 13, w1 -2.889378, w2 -0.347323
epoch 14, w1 -2.806763, w2 -0.310253
epoch 15, w1 -2.728078, w2 -0.277253
epoch 16, w1 -2.652960, w2 -0.247842
epoch 17, w1 -2.581099, w2 -0.221608
epoch 18, w1 -2.512228, w2 -0.198191
epoch 19, w1 -2.446117, w2 -0.177277
epoch 20, w1 -2.382563, w2 -0.158591
In [ ]:
show_trace_2d(l_2i, gd_2d(adagrad_2d,20,2,1,1e-6))
epoch 1, w1 -3.000001, w2 -0.000000
epoch 2, w1 -1.971010, w2 -0.000000
epoch 3, w1 -1.330559, w2 -0.000000
epoch 4, w1 -0.907975, w2 -0.000000
epoch 5, w1 -0.622554, w2 -0.000000
epoch 6, w1 -0.427785, w2 -0.000000
epoch 7, w1 -0.294250, w2 -0.000000
epoch 8, w1 -0.202494, w2 -0.000000
epoch 9, w1 -0.139383, w2 -0.000000
epoch 10, w1 -0.095951, w2 -0.000000
epoch 11, w1 -0.066056, w2 -0.000000
epoch 12, w1 -0.045477, w2 -0.000000
epoch 13, w1 -0.031309, w2 -0.000000
epoch 14, w1 -0.021555, w2 -0.000000
epoch 15, w1 -0.014840, w2 -0.000000
epoch 16, w1 -0.010217, w2 -0.000000
epoch 17, w1 -0.007034, w2 -0.000000
epoch 18, w1 -0.004843, w2 -0.000000
epoch 19, w1 -0.003334, w2 -0.000000
epoch 20, w1 -0.002295, w2 -0.000000

5. RMSProp¶

The issue with Adagrad is that the learning rate decreases on a predefined schedule: $s_t = s_{t-1} + g_t^2$ keeps growing without bound, so the effective step size can shrink too quickly. [Tieleman & Hinton, 2012] proposed RMSProp as a simple fix: introduce a decay factor $\gamma \in (0,1)$ and use $s_t = \gamma s_{t-1} + (1-\gamma)g_t^2$ instead.

We again use $l(w) = 0.1 w_1^2 + 2 w_2^2$ as the example.

In [ ]:
def rmsprop_2d(w1, w2, s1, s2, eta=0.4, gamma=0.9):
    g1, g2, eps = 0.2 * w1, 4 * w2, 1e-6
    s1 = gamma * s1 + (1 - gamma) * g1 ** 2
    s2 = gamma * s2 + (1 - gamma) * g2 ** 2
    w1 -= eta / math.sqrt(s1 + eps) * g1
    w2 -= eta / math.sqrt(s2 + eps) * g2
    return w1, w2, s1, s2
In [ ]:
show_trace_2d(l_2i, gd_2d(rmsprop_2d,20,0.4,1,0.9))
epoch 1, w1 -3.735095, w2 -0.735089
epoch 2, w1 -2.952557, w2 -0.278126
epoch 3, w1 -2.372981, w2 -0.097741
epoch 4, w1 -1.915252, w2 -0.031013
epoch 5, w1 -1.543071, w2 -0.008699
epoch 6, w1 -1.236422, w2 -0.002101
epoch 7, w1 -0.982686, w2 -0.000421
epoch 8, w1 -0.773052, w2 -0.000066
epoch 9, w1 -0.600837, w2 -0.000007
epoch 10, w1 -0.460616, w2 -0.000000
epoch 11, w1 -0.347757, w2 -0.000000
epoch 12, w1 -0.258167, w2 0.000000
epoch 13, w1 -0.188167, w2 -0.000000
epoch 14, w1 -0.134436, w2 0.000000
epoch 15, w1 -0.093992, w2 -0.000000
epoch 16, w1 -0.064194, w2 0.000000
epoch 17, w1 -0.042745, w2 -0.000000
epoch 18, w1 -0.027691, w2 0.000000
epoch 19, w1 -0.017412, w2 -0.000000
epoch 20, w1 -0.010599, w2 0.000000
In [ ]:
show_trace_2d(l_2i, gd_2d(rmsprop_2d,20,0.1,1,0.9))
epoch 1, w1 -4.683774, w2 -1.683772
epoch 2, w1 -4.461587, w2 -1.473875
epoch 3, w1 -4.279291, w2 -1.308718
epoch 4, w1 -4.120057, w2 -1.169840
epoch 5, w1 -3.976157, w2 -1.048927
epoch 6, w1 -3.843313, w2 -0.941451
epoch 7, w1 -3.718881, w2 -0.844649
epoch 8, w1 -3.601097, w2 -0.756714
epoch 9, w1 -3.488721, w2 -0.676396
epoch 10, w1 -3.380846, w2 -0.602797
epoch 11, w1 -3.276789, w2 -0.535254
epoch 12, w1 -3.176022, w2 -0.473262
epoch 13, w1 -3.078127, w2 -0.416425
epoch 14, w1 -2.982772, w2 -0.364427
epoch 15, w1 -2.889688, w2 -0.317002
epoch 16, w1 -2.798656, w2 -0.273923
epoch 17, w1 -2.709492, w2 -0.234983
epoch 18, w1 -2.622048, w2 -0.199989
epoch 19, w1 -2.536198, w2 -0.168748
epoch 20, w1 -2.451840, w2 -0.141068
In [ ]:
show_trace_2d(l_2i, gd_2d(rmsprop_2d,20,1,1,0.9))
epoch 1, w1 -1.837738, w2 1.162277
epoch 2, w1 -0.695328, w2 -0.489563
epoch 3, w1 -0.244360, w2 0.224882
epoch 4, w1 -0.077536, w2 -0.119001
epoch 5, w1 -0.021748, w2 0.072463
epoch 6, w1 -0.005254, w2 -0.050339
epoch 7, w1 -0.001054, w2 0.039548
epoch 8, w1 -0.000166, w2 -0.034869
epoch 9, w1 -0.000019, w2 0.034277
epoch 10, w1 -0.000001, w2 -0.037353
epoch 11, w1 -0.000000, w2 0.044899
epoch 12, w1 0.000000, w2 -0.059262
epoch 13, w1 -0.000000, w2 0.085504
epoch 14, w1 0.000000, w2 -0.134132
epoch 15, w1 -0.000000, w2 0.226683
epoch 16, w1 0.000000, w2 -0.403198
epoch 17, w1 -0.000000, w2 0.703135
epoch 18, w1 0.000000, w2 -1.007366
epoch 19, w1 -0.000000, w2 0.993176
epoch 20, w1 0.000000, w2 -0.744051
In [ ]:
show_trace_2d(l_2i, gd_2d(rmsprop_2d,20,0.4,1,0.5))
epoch 1, w1 -4.434315, w2 -1.434315
epoch 2, w1 -3.992010, w2 -1.031502
epoch 3, w1 -3.592920, w2 -0.699698
epoch 4, w1 -3.214966, w2 -0.422297
epoch 5, w1 -2.849734, w2 -0.203885
epoch 6, w1 -2.493852, w2 -0.059684
epoch 7, w1 -2.146327, w2 -0.000316
epoch 8, w1 -1.807572, w2 0.000129
epoch 9, w1 -1.479098, w2 -0.000127
epoch 10, w1 -1.163595, w2 0.000231
epoch 11, w1 -0.865337, w2 -0.000687
epoch 12, w1 -0.591009, w2 0.003179
epoch 13, w1 -0.351060, w2 -0.022093
epoch 14, w1 -0.161189, w2 0.205325
epoch 15, w1 -0.040729, w2 -0.350494
epoch 16, w1 0.002190, w2 0.170762
epoch 17, w1 -0.001073, w2 -0.132441
epoch 18, w1 0.001188, w2 0.154252
epoch 19, w1 -0.002351, w2 -0.208258
epoch 20, w1 0.007544, w2 0.229752
In [ ]:
show_trace_2d(l_2i, gd_2d(rmsprop_2d,20,0.4,1,0.7))
epoch 1, w1 -4.269704, w2 -1.269703
epoch 2, w1 -3.748056, w2 -0.828258
epoch 3, w1 -3.310087, w2 -0.516918
epoch 4, w1 -2.919471, w2 -0.295598
epoch 5, w1 -2.560782, w2 -0.147472
epoch 6, w1 -2.226457, w2 -0.059785
epoch 7, w1 -1.912725, w2 -0.017368
epoch 8, w1 -1.617984, w2 -0.002643
epoch 9, w1 -1.342072, w2 0.000035
epoch 10, w1 -1.085911, w2 -0.000007
epoch 11, w1 -0.851310, w2 0.000003
epoch 12, w1 -0.640815, w2 -0.000002
epoch 13, w1 -0.457498, w2 0.000003
epoch 14, w1 -0.304541, w2 -0.000004
epoch 15, w1 -0.184501, w2 0.000007
epoch 16, w1 -0.098188, w2 -0.000019
epoch 17, w1 -0.043441, w2 0.000061
epoch 18, w1 -0.014514, w2 -0.000246
epoch 19, w1 -0.002964, w2 0.001239
epoch 20, w1 -0.000145, w2 -0.007692
In [ ]:
show_trace_2d(l_2i, gd_2d(rmsprop_2d,100,0.4,1,0.7))
epoch 1, w1 -4.269704, w2 -1.269703
epoch 2, w1 -3.748056, w2 -0.828258
epoch 3, w1 -3.310087, w2 -0.516918
epoch 4, w1 -2.919471, w2 -0.295598
epoch 5, w1 -2.560782, w2 -0.147472
epoch 6, w1 -2.226457, w2 -0.059785
epoch 7, w1 -1.912725, w2 -0.017368
epoch 8, w1 -1.617984, w2 -0.002643
epoch 9, w1 -1.342072, w2 0.000035
epoch 10, w1 -1.085911, w2 -0.000007
epoch 11, w1 -0.851310, w2 0.000003
epoch 12, w1 -0.640815, w2 -0.000002
epoch 13, w1 -0.457498, w2 0.000003
epoch 14, w1 -0.304541, w2 -0.000004
epoch 15, w1 -0.184501, w2 0.000007
epoch 16, w1 -0.098188, w2 -0.000019
epoch 17, w1 -0.043441, w2 0.000061
epoch 18, w1 -0.014514, w2 -0.000246
epoch 19, w1 -0.002964, w2 0.001239
epoch 20, w1 -0.000145, w2 -0.007692
epoch 21, w1 0.000020, w2 0.058281
epoch 22, w1 -0.000007, w2 -0.404153
epoch 23, w1 0.000004, w2 0.313238
epoch 24, w1 -0.000004, w2 -0.178278
epoch 25, w1 0.000006, w2 0.125732
epoch 26, w1 -0.000010, w2 -0.116077
epoch 27, w1 0.000023, w2 0.134542
epoch 28, w1 -0.000067, w2 -0.179022
epoch 29, w1 0.000248, w2 0.232807
epoch 30, w1 -0.001153, w2 -0.248567
epoch 31, w1 0.006619, w2 0.221536
epoch 32, w1 -0.046477, w2 -0.191469
epoch 33, w1 0.333333, w2 0.176914
epoch 34, w1 -0.379181, w2 -0.178490
epoch 35, w1 0.203944, w2 0.191136
epoch 36, w1 -0.129532, w2 -0.205920
epoch 37, w1 0.109638, w2 0.212917
epoch 38, w1 -0.120020, w2 -0.209382
epoch 39, w1 0.157828, w2 0.201137
epoch 40, w1 -0.216927, w2 -0.194885
epoch 41, w1 0.253725, w2 0.193503
epoch 42, w1 -0.235060, w2 -0.196277
epoch 43, w1 0.199739, w2 0.200464
epoch 44, w1 -0.178112, w2 -0.203157
epoch 45, w1 0.174508, w2 0.203109
epoch 46, w1 -0.184908, w2 -0.201180
epoch 47, w1 0.201346, w2 0.199134
epoch 48, w1 -0.212696, w2 -0.198225
epoch 49, w1 0.212349, w2 0.198641
epoch 50, w1 -0.204198, w2 -0.199744
epoch 51, w1 0.196156, w2 0.200673
epoch 52, w1 -0.192831, w2 -0.200919
epoch 53, w1 0.194562, w2 0.200536
epoch 54, w1 -0.198944, w2 -0.199947
epoch 55, w1 0.202671, w2 0.199567
epoch 56, w1 -0.203623, w2 -0.199560
epoch 57, w1 0.202028, w2 0.199818
epoch 58, w1 -0.199661, w2 -0.200109
epoch 59, w1 0.198183, w2 0.200246
epoch 60, w1 -0.198189, w2 -0.200195
epoch 61, w1 0.199231, w2 0.200042
epoch 62, w1 -0.200388, w2 -0.199909
epoch 63, w1 0.200929, w2 0.199871
epoch 64, w1 -0.200709, w2 -0.199921
epoch 65, w1 0.200091, w2 0.200004
epoch 66, w1 -0.199568, w2 -0.200059
epoch 67, w1 0.199423, w2 0.200062
epoch 68, w1 -0.199629, w2 -0.200027
epoch 69, w1 0.199958, w2 0.199986
epoch 70, w1 -0.200177, w2 -0.199966
epoch 71, w1 0.200186, w2 0.199972
epoch 72, w1 -0.200043, w2 -0.199993
epoch 73, w1 0.199880, w2 0.200012
epoch 74, w1 -0.199800, w2 -0.200018
epoch 75, w1 0.199827, w2 0.200011
epoch 76, w1 -0.199912, w2 -0.200000
epoch 77, w1 0.199987, w2 0.199992
epoch 78, w1 -0.200010, w2 -0.199991
epoch 79, w1 0.199983, w2 0.199996
epoch 80, w1 -0.199937, w2 -0.200002
epoch 81, w1 0.199905, w2 0.200005
epoch 82, w1 -0.199902, w2 -0.200004
epoch 83, w1 0.199922, w2 0.200001
epoch 84, w1 -0.199945, w2 -0.199998
epoch 85, w1 0.199956, w2 0.199997
epoch 86, w1 -0.199953, w2 -0.199998
epoch 87, w1 0.199942, w2 0.200000
epoch 88, w1 -0.199931, w2 -0.200001
epoch 89, w1 0.199927, w2 0.200001
epoch 90, w1 -0.199931, w2 -0.200000
epoch 91, w1 0.199937, w2 0.200000
epoch 92, w1 -0.199942, w2 -0.199999
epoch 93, w1 0.199942, w2 0.199999
epoch 94, w1 -0.199940, w2 -0.200000
epoch 95, w1 0.199937, w2 0.200000
epoch 96, w1 -0.199935, w2 -0.200000
epoch 97, w1 0.199935, w2 0.200000
epoch 98, w1 -0.199937, w2 -0.200000
epoch 99, w1 0.199938, w2 0.200000
epoch 100, w1 -0.199939, w2 -0.200000

6. Adam¶

Until now, we have covered gradient descent, momentum, Adagrad, and RMSProp for optimizing a loss function. In 2014, Kingma & Ba proposed the Adam algorithm, which combines these techniques into one efficient learning algorithm. This does not mean Adam is free of issues: in some cases it fails to converge, but it remains a very popular method.

$v_t \leftarrow \beta_1 v_{t-1} + (1-\beta_1)g_t$

$s_t \leftarrow \beta_2 s_{t-1} + (1-\beta_2)g_t^2$

Because $v_0 = s_0 = 0$, these estimates are biased toward zero early on, so Adam rescales them as $\hat{v}_t = \frac{v_t}{1-\beta_1^t}$ and $\hat{s}_t = \frac{s_t}{1-\beta_2^t}$ (the v_corr and s_corr terms in the code below), and then updates

$w_t = w_{t-1} - \frac{\eta\,\hat{v}_t}{\sqrt{\hat{s}_t + \epsilon}}$

Common choices are $\beta_1=0.9$ and $\beta_2=0.999$.

In [ ]:
def adam_2d(w1, w2, s1, s2, v1, v2, lr=0.01, t=1):
    g1, g2, eps = 0.2 * w1, 4 * w2, 1e-6
    beta1, beta2 = 0.9, 0.999
    # Exponential moving average of the squared gradients (second moment)
    s1 = beta2 * s1 + (1 - beta2) * g1 ** 2
    s2 = beta2 * s2 + (1 - beta2) * g2 ** 2
    # Bias correction: s_t is biased toward 0 for small t
    s1_corr = s1 / (1 - beta2 ** t)
    s2_corr = s2 / (1 - beta2 ** t)

    # Exponential moving average of the gradients (first moment)
    v1 = beta1 * v1 + (1 - beta1) * g1
    v2 = beta1 * v2 + (1 - beta1) * g2
    v1_corr = v1 / (1 - beta1 ** t)
    v2_corr = v2 / (1 - beta1 ** t)

    w1 -= lr * v1_corr / (np.sqrt(s1_corr + eps))
    w2 -= lr * v2_corr / (np.sqrt(s2_corr + eps))
    return w1, w2, s1, s2, v1, v2

def gd_2d_adam(_gd, epoches=10, lr=0.01, output_det=0, t=1):
    """Optimize a 2-dim objective function with a customized trainer."""
    w1, w2, s1, s2, v1, v2 = -5, -2, 0, 0, 0, 0
    results = [(w1, w2)]
    for i in range(epoches):
        w1, w2, s1, s2, v1, v2 = _gd(w1, w2, s1, s2, v1, v2, lr, t)
        t += 1
        results.append((w1, w2))
        if output_det>0 and (epoches<=20 or (epoches>20 and i%10==0)):
            print('epoch %d, w1 %f, w2 %f' % (i + 1, w1, w2))
    if output_det==0: print('epoch %d, w1 %f, w2 %f' % (i + 1, w1, w2))
    return results
In [ ]:
show_trace_2d(l_2i, gd_2d_adam(adam_2d,20,0.1,1,1))
epoch 1, w1 -4.900000, w2 -1.900000
epoch 2, w1 -4.800058, w2 -1.800166
epoch 3, w1 -4.700213, w2 -1.700623
epoch 4, w1 -4.600508, w2 -1.601505
epoch 5, w1 -4.500983, w2 -1.502956
epoch 6, w1 -4.401682, w2 -1.405132
epoch 7, w1 -4.302649, w2 -1.308200
epoch 8, w1 -4.203928, w2 -1.212337
epoch 9, w1 -4.105564, w2 -1.117733
epoch 10, w1 -4.007602, w2 -1.024587
epoch 11, w1 -3.910089, w2 -0.933108
epoch 12, w1 -3.813072, w2 -0.843515
epoch 13, w1 -3.716596, w2 -0.756035
epoch 14, w1 -3.620708, w2 -0.670898
epoch 15, w1 -3.525455, w2 -0.588342
epoch 16, w1 -3.430884, w2 -0.508604
epoch 17, w1 -3.337041, w2 -0.431920
epoch 18, w1 -3.243972, w2 -0.358521
epoch 19, w1 -3.151724, w2 -0.288628
epoch 20, w1 -3.060340, w2 -0.222452
In [ ]:
show_trace_2d(l_2i, gd_2d_adam(adam_2d,100,0.1,1,1))
epoch 1, w1 -4.900000, w2 -1.900000
epoch 11, w1 -3.910089, w2 -0.933108
epoch 21, w1 -2.969866, w2 -0.160186
epoch 31, w1 -2.124379, w2 0.224022
epoch 41, w1 -1.411153, w2 0.221906
epoch 51, w1 -0.852443, w2 0.059513
epoch 61, w1 -0.451171, w2 -0.048372
epoch 71, w1 -0.191737, w2 -0.047235
epoch 81, w1 -0.045545, w2 -0.004101
epoch 91, w1 0.021049, w2 0.015530
In [ ]:
show_trace_2d(l_2i, gd_2d_adam(adam_2d,200,0.1,1,1))
epoch 1, w1 -4.900000, w2 -1.900000
epoch 11, w1 -3.910089, w2 -0.933108
epoch 21, w1 -2.969866, w2 -0.160186
epoch 31, w1 -2.124379, w2 0.224022
epoch 41, w1 -1.411153, w2 0.221906
epoch 51, w1 -0.852443, w2 0.059513
epoch 61, w1 -0.451171, w2 -0.048372
epoch 71, w1 -0.191737, w2 -0.047235
epoch 81, w1 -0.045545, w2 -0.004101
epoch 91, w1 0.021049, w2 0.015530
epoch 101, w1 0.039438, w2 0.007070
epoch 111, w1 0.034154, w2 -0.003383
epoch 121, w1 0.021345, w2 -0.003451
epoch 131, w1 0.009604, w2 0.000421
epoch 141, w1 0.002038, w2 0.001312
epoch 151, w1 -0.001421, w2 0.000051
epoch 161, w1 -0.002141, w2 -0.000460
epoch 171, w1 -0.001580, w2 -0.000057
epoch 181, w1 -0.000757, w2 0.000160
epoch 191, w1 -0.000155, w2 0.000023